{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualizing the predicate shifts\n",
    "\n",
    "In the paper, we visualize all the predicate shifts that we learn. This notebook takes you through the process of creating those visualizations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from utils.visualization_utils import get_att_map, objdict, get_dict\n",
    "from scipy.stats import multivariate_normal\n",
    "import keras.backend as K\n",
    "\n",
    "import numpy as np\n",
    "import os\n",
    "from PIL import Image\n",
    "import json\n",
    "import h5py\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Build the initial attention map: a 2-D Gaussian sampled on a small grid.\n",
    "im_width = 14\n",
    "im_height = 14\n",
    "\n",
    "def create_gaussian(center):\n",
    "    \"\"\"Evaluate an identity-covariance 2-D Gaussian pdf on the attention grid.\n",
    "\n",
    "    The grid covers [-2, 2] on both axes, with `im_width` samples along x\n",
    "    and `im_height` samples along y (both read from module-level globals).\n",
    "\n",
    "    Args:\n",
    "        center: Length-2 mean (x, y) of the Gaussian in grid coordinates.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (im_height, im_width) holding the pdf values.\n",
    "    \"\"\"\n",
    "    lower, upper = -2, 2\n",
    "    cols = np.linspace(lower, upper, im_width)\n",
    "    rows = np.linspace(lower, upper, im_height)\n",
    "    grid_x, grid_y = np.meshgrid(cols, rows)\n",
    "    # One (x, y) row per grid point, flattened in the same order as the grid.\n",
    "    points = np.column_stack((grid_x.ravel(), grid_y.ravel()))\n",
    "    gaussian = multivariate_normal(mean=center, cov=np.eye(2))\n",
    "    return gaussian.pdf(points).reshape((im_height, im_width))\n",
    "\n",
    "in_att = create_gaussian((0, 0))\n",
    "plt.imshow(in_att, interpolation='spline16')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Choose the dataset we want to visualize the predicates for."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "###################\n",
    "data_type = \"visualgenome\"\n",
    "###################\n",
    "# Each branch sets the plotting grid (nrows, ncols, figsize), the pretrained\n",
    "# checkpoint path and the vocabulary directory for the chosen dataset.\n",
    "if data_type==\"vrd\":\n",
    "    nrows=7\n",
    "    ncols=10\n",
    "    figsize = (14,20)\n",
    "    sym_ssn_checkpoint = \"pretrained/vrd.h5\"\n",
    "    vocab_dir = os.path.join('data/VRD')\n",
    "elif data_type==\"clevr\":\n",
    "    nrows=3\n",
    "    ncols=4\n",
    "    figsize = (12,6)\n",
    "    sym_ssn_checkpoint = \"pretrained/clevr.h5\"\n",
    "    vocab_dir = os.path.join('data/Clevr/')\n",
    "elif data_type==\"visualgenome\":\n",
    "    nrows=7\n",
    "    ncols=10\n",
    "    figsize = (14,20)\n",
    "    # Not read anywhere in this notebook; kept so the cell defines the same names.\n",
    "    ssn_checkpoint = \"\"\n",
    "    sym_ssn_checkpoint = \"pretrained/visualgenome.h5\"\n",
    "    vocab_dir = os.path.join('data/VisualGenome/')\n",
    "else:\n",
    "    # Without this branch a typo would leave the settings undefined (or stale\n",
    "    # from an earlier run) and only fail, confusingly, in a later cell.\n",
    "    raise ValueError(\n",
    "        \"Unknown data_type '{}'; expected 'vrd', 'clevr' or 'visualgenome'.\".format(data_type))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Grab all the weights\n",
    "predicate_dict, obj_subj_dict = get_dict(vocab_dir)\n",
    "\n",
    "# Read the training arguments saved next to the checkpoint; the context\n",
    "# manager closes the file handle instead of leaking it.\n",
    "args_path = os.path.join(os.path.dirname(sym_ssn_checkpoint), \"args.json\")\n",
    "with open(args_path, \"r\") as args_file:\n",
    "    params = objdict(json.load(args_file))\n",
    "\n",
    "def _read_kernel(weights, layer_name):\n",
    "    \"\"\"Return the 'kernel:0' dataset of `layer_name` as an in-memory array.\"\"\"\n",
    "    return weights[layer_name][layer_name]['kernel:0'][()]\n",
    "\n",
    "conv_filters = {}\n",
    "inv_conv_filters = {}\n",
    "# Open the checkpoint explicitly read-only: h5py releases before 3.0 do not\n",
    "# default to read-only and can create a missing file when the path is wrong.\n",
    "with h5py.File(sym_ssn_checkpoint, 'r') as model_weights:\n",
    "    # Two layer-naming schemes are supported. The check does not depend on\n",
    "    # the loop indices, so it is done once up front.\n",
    "    if 'conv0-predicate0' in model_weights:\n",
    "        conv_name_format = \"conv{}-predicate{}\"\n",
    "        inv_conv_name_format = \"conv{}-inv-predicate{}\"\n",
    "    else:\n",
    "        conv_name_format = \"conv{}-predicate{}-0\"\n",
    "        inv_conv_name_format = \"conv{}-predicate{}-1\"\n",
    "    for i in range(params.num_predicates):\n",
    "        predicate = predicate_dict[i]\n",
    "        # `[()]` reads each kernel into memory, so the filters stay usable\n",
    "        # after the checkpoint file is closed.\n",
    "        conv_filters[predicate] = [\n",
    "            _read_kernel(model_weights, conv_name_format.format(j, i))\n",
    "            for j in range(params.nb_conv_att_map)]\n",
    "        inv_conv_filters[predicate] = [\n",
    "            _read_kernel(model_weights, inv_conv_name_format.format(j, i))\n",
    "            for j in range(params.nb_conv_att_map)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Before we continue, let's visualize just one of them to make sure that everything works.\n",
    "\n",
    "Make sure that 'above' is actually a predicate in the dataset you are visualizing. Otherwise, set `predicate` in the next cell to a different one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "###################\n",
    "predicate = \"above\"\n",
    "###################\n",
    "# Fail early, before a TF session exists, with a message that lists the\n",
    "# predicates that are actually available for the loaded dataset.\n",
    "if predicate not in conv_filters:\n",
    "    raise KeyError(\"Predicate '{}' was not loaded; choose one of: {}\".format(\n",
    "        predicate, sorted(conv_filters)))\n",
    "\n",
    "sess = tf.InteractiveSession()\n",
    "try:\n",
    "    # Add batch and channel axes, then run the predicate's conv + ReLU stack.\n",
    "    att = K.constant(in_att.reshape(1, im_height, im_width, 1))\n",
    "    for j in range(params.nb_conv_att_map):\n",
    "        kernel = np.array(conv_filters[predicate][j])\n",
    "        att = K.conv2d(att, kernel, padding='same', data_format='channels_last')\n",
    "        att = K.relu(att)\n",
    "    # Sum over the channel axis and drop the batch axis to get a 2-D map.\n",
    "    att = K.sum(att, axis=3)\n",
    "    att = att.eval().reshape((im_height, im_width))\n",
    "finally:\n",
    "    # Always release the session, even if evaluation raised.\n",
    "    sess.close()\n",
    "plt.imshow(att, interpolation='gaussian')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's compute the shifts for all the predicates"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Compute all the attentions\n",
    "def _shift_attention(att_tensor, kernels):\n",
    "    \"\"\"Run an attention tensor through a stack of conv + ReLU layers.\n",
    "\n",
    "    Args:\n",
    "        att_tensor: Keras tensor in channels_last layout.\n",
    "        kernels: Sequence of conv kernels, applied in order.\n",
    "\n",
    "    Returns:\n",
    "        Keras tensor with axis 3 (the channel axis) summed out.\n",
    "    \"\"\"\n",
    "    for kernel in kernels:\n",
    "        att_tensor = K.conv2d(att_tensor, np.array(kernel), padding='same', data_format='channels_last')\n",
    "        att_tensor = K.relu(att_tensor)\n",
    "    return K.sum(att_tensor, axis=3)\n",
    "\n",
    "shifts = {}\n",
    "inv_shifts = {}\n",
    "# Keep the batched copy under its own name so `in_att` is not rebound to a\n",
    "# 4-D array and re-running cells gives the same result.\n",
    "in_att_batch = in_att.reshape(1, im_height, im_width, 1)\n",
    "sess = tf.InteractiveSession()\n",
    "try:\n",
    "    # The input is the same for every predicate, so build the constant once.\n",
    "    att_input = K.constant(in_att_batch)\n",
    "    for i in range(params.num_predicates):\n",
    "        predicate = predicate_dict[i]\n",
    "        att = _shift_attention(att_input, conv_filters[predicate])\n",
    "        inv_att = _shift_attention(att_input, inv_conv_filters[predicate])\n",
    "        shifts[predicate] = att.eval().reshape((im_height, im_width))\n",
    "        inv_shifts[predicate] = inv_att.eval().reshape((im_height, im_width))\n",
    "finally:\n",
    "    # Always release the session, even if a predicate failed to evaluate.\n",
    "    sess.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Now let's visualize all of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "######################\n",
    "interp_method = 'spline16'\n",
    "######################\n",
    "# plot all the shifts\n",
    "# Each predicate takes two adjacent columns: its shift and its inverse shift.\n",
    "pairs_per_row = ncols // 2\n",
    "grid_capacity = nrows * 2 * pairs_per_row\n",
    "if params.num_predicates > grid_capacity:\n",
    "    raise ValueError(\n",
    "        \"A {} x {} grid holds {} predicates but {} are needed; increase nrows/ncols.\".format(\n",
    "            nrows * 2, ncols, grid_capacity, params.num_predicates))\n",
    "\n",
    "# squeeze=False keeps `axs` two-dimensional for any grid size.\n",
    "fig, axs = plt.subplots(nrows=nrows*2, ncols=ncols, figsize=figsize, squeeze=False)\n",
    "for idx in range(params.num_predicates):\n",
    "    row, pair = divmod(idx, pairs_per_row)\n",
    "    col = 2 * pair\n",
    "    predicate = predicate_dict[idx]\n",
    "    axs[row, col].imshow(shifts[predicate], interpolation=interp_method)\n",
    "    axs[row, col].set_title(predicate)\n",
    "    axs[row, col + 1].imshow(inv_shifts[predicate], interpolation=interp_method)\n",
    "    axs[row, col + 1].set_title(\"INV {}\".format(predicate))\n",
    "\n",
    "# Hide ticks and frames on every axis, including grid slots left unused.\n",
    "for ax in axs.ravel():\n",
    "    ax.axis(\"off\")\n",
    "# Lay out only after the titles exist so the spacing accounts for them.\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
